{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastai.vision import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = Config().data_path()/'imagenet'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Inputs: precomputed activations"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will build a model on the whole of ImageNet, so we will compute once and for all the activations for the whole training and validation set. We use `presize` to set the images to 224 x 224 by just using PIL."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "src = (ImageList.from_folder(path)\n",
    "          .split_by_folder()\n",
    "          .label_from_folder())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = src.presize(size=224).databunch(bs=256).normalize(imagenet_stats)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is the pretrained resnet50 with concat pool and flatten:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# resnet50 body, followed by concat pooling and flatten, on the default device\n",
    "layers = [*create_body(models.resnet50).children(), AdaptiveConcatPool2d(), Flatten()]\n",
    "body = nn.Sequential(*layers).to(defaults.device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will use bcolz to store our activations in an array that's saved to disk (they won't all fit in RAM). Install with \n",
    "```\n",
    "pip install -U bcolz\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import bcolz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tmp_path = path/'tmp'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#To clean-up previous tries\n",
    "#shutil.rmtree(tmp_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following functions store the precomputed activations under `tmp_path`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def precompute_activations_dl(dl, model, path:Path, force:bool=False, n_features:int=4096):\n",
    "    \"\"\"Run `model` on every batch of `dl` and append its outputs to a bcolz array rooted at `path`.\n",
    "\n",
    "    Does nothing (beyond putting `model` in eval mode) when `path` already exists, unless `force` is True.\n",
    "    `n_features` is the width of one output row; it defaults to 4096, the value this notebook uses.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    if os.path.exists(path) and not force: return\n",
    "    # Start from an empty (0, n_features) float32 array; each batch of outputs is appended as new rows.\n",
    "    arr = bcolz.carray(np.zeros((0,n_features), np.float32), chunklen=1, mode='w', rootdir=path)\n",
    "    with torch.no_grad():\n",
    "        for x,_ in progress_bar(dl):\n",
    "            arr.append(model(x).cpu().numpy())\n",
    "            arr.flush()  # flush after every batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def precompute_activations(data, model, path:Path, force:bool=False):\n",
    "    \"Store the activations of `model` for the training and validation sets of `data` in `path/'train'` and `path/'valid'`.\"\n",
    "    os.makedirs(path, exist_ok=True)\n",
    "    # Use fix_dl and not train_dl for the training set, to get shuffle=False\n",
    "    for split,dl in (('train', data.fix_dl), ('valid', data.valid_dl)):\n",
    "        precompute_activations_dl(dl, model, path/split, force=force)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "precompute_activations(data, body, tmp_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Save the labels and the filenames in the same order as our activations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.save(tmp_path/'trn_lbl.npy', data.train_ds.y.items)\n",
    "np.save(tmp_path/'val_lbl.npy', data.valid_ds.y.items)\n",
    "save_texts(tmp_path/'classes.txt', data.train_ds.classes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.save(tmp_path/'trn_names.npy', data.train_ds.x.items)\n",
    "np.save(tmp_path/'val_names.npy', data.valid_ds.x.items)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To load our precomputed activations, we'll use the following `ItemList`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BcolzItemList(ItemList):\n",
    "    \"`ItemList` whose items are the row indices of the bcolz array opened from `path`.\"\n",
    "    def __init__(self, path, **kwargs):\n",
    "        self.arr = bcolz.open(path)\n",
    "        n_rows = len(self.arr)\n",
    "        super().__init__(range(n_rows), **kwargs)\n",
    "\n",
    "    def get(self, i):\n",
    "        \"Return `self.arr[i]`.\"\n",
    "        return self.arr[i]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# tmp_path is path/'tmp', the folder `precompute_activations` was pointed at above\n",
    "src = ItemLists(path, BcolzItemList(tmp_path/'train'), BcolzItemList(tmp_path/'valid'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Targets: word vectors"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We build a regression model that has to predict a vector from the image features. We need to associate a word vector to each one of our 1000 classes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "classes = loadtxt_str(tmp_path/'classes.txt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "classes, len(classes)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The labels in imagenet are codes that come from [wordnet](https://wordnet.princeton.edu/). So let's download the corresponding dictionary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "WORDNET = 'classids.txt'\n",
    "download_url(f'http://files.fast.ai/data/{WORDNET}', path/'tmp'/WORDNET)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each line is split on whitespace into a (key, value) pair for the dictionary\n",
    "wordnet_lines = loadtxt_str(path/'tmp'/WORDNET)\n",
    "class_ids = dict(l.strip().split() for l in wordnet_lines)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "named_classes = [class_ids[c] for c in classes]\n",
    "named_classes[:10]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will train our model to predict not the label of its class, but the corresponding pretrained vector. There are plenty of word embeddings available, here we will use fastText.\n",
    "\n",
    "To install fastText:\n",
    "```\n",
    "$ git clone https://github.com/facebookresearch/fastText.git\n",
    "$ cd fastText\n",
    "$ pip install .\n",
    "```\n",
    "\n",
    "To download the English embeddings and unzip them (the notebook loads `cc.en.300.bin` from `path`):\n",
    "\n",
    "```\n",
    "$ wget https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.en.300.bin.gz\n",
    "$ gunzip cc.en.300.bin.gz\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import fasttext as ft\n",
    "en_vecs = ft.load_model(str((path/'cc.en.300.bin')))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A lot of our classes are actually composed of several words separated by a `_`. The pretrained word vectors from fastText won't know them directly, but fastText can still compute a vector to represent them:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vec_dog = en_vecs.get_sentence_vector('dog')\n",
    "vec_lab = en_vecs.get_sentence_vector('labrador')\n",
    "vec_gor = en_vecs.get_sentence_vector('golden retriever')\n",
    "vec_ban = en_vecs.get_sentence_vector('banana')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To check if two word vectors are close or not, we use cosine similarity."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "F.cosine_similarity(tensor(vec_dog[None]), tensor(vec_lab[None]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "F.cosine_similarity(tensor(vec_dog[None]), tensor(vec_ban[None]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "F.cosine_similarity(tensor(vec_lab[None]), tensor(vec_ban[None]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "F.cosine_similarity(tensor(vec_dog[None]), tensor(vec_gor[None]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "F.cosine_similarity(tensor(vec_lab[None]), tensor(vec_gor[None]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "So let's grab all the word vectors for all our classes:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One fastText sentence vector per class name, with underscores replaced by spaces\n",
    "vecs = [en_vecs.get_sentence_vector(name.replace('_', ' ')) for name in named_classes]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then we label each feature map with the word vector of its target."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the saved labels and use each one as an index into `vecs`\n",
    "train_labels = np.load(tmp_path/'trn_lbl.npy')\n",
    "valid_labels = np.load(tmp_path/'val_lbl.npy')\n",
    "train_vecs,valid_vecs = ([vecs[i] for i in lbls] for lbls in (train_labels, valid_labels))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We use our custom `BcolzItemList` to gather the data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "src = ItemLists(path, BcolzItemList(tmp_path/'train'), BcolzItemList(tmp_path/'valid'))\n",
    "src = src.label_from_lists(train_vecs, valid_vecs, label_cls=FloatList)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = src.databunch(bs=512, num_workers=16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = create_head(4096, data.c, lin_ftrs = [1024], ps=[0.2,0.2])\n",
    "model = nn.Sequential(*list(model.children())[2:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def cos_loss(inp,targ): return 1 - F.cosine_similarity(inp,targ).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = Learner(data, model, loss_func=cos_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.lr_find()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.recorder.plot()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.fit_one_cycle(15,3e-2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.save('fit')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.model.eval()\n",
    "preds = []\n",
    "with torch.no_grad():\n",
    "    # Training set first (through `fix_dl`), then the validation set\n",
    "    for dl in (learn.data.fix_dl, learn.data.valid_dl):\n",
    "        for x,_ in progress_bar(dl):\n",
    "            preds.append(learn.model(x).cpu().numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "preds = np.concatenate(preds, 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.save(path/'preds.npy', preds)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Looking at predicted tags in image classes"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we will check, for a given image, which word vectors are the closest to its prediction. To compute this quickly, we use `nmslib`, which is very fast (`pip install nmslib`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import nmslib\n",
    "\n",
    "def create_index(a):\n",
    "    \"Build and return an nmslib index (space 'angulardist') over the data points in `a`.\"\n",
    "    idx = nmslib.init(space='angulardist')\n",
    "    idx.addDataPointBatch(a)\n",
    "    idx.createIndex()\n",
    "    return idx\n",
    "\n",
    "def get_knns(index, vecs):\n",
    "     return zip(*index.knnQueryBatch(vecs, k=10, num_threads=4))\n",
    "\n",
    "def get_knn(index, vec): return index.knnQuery(vec, k=10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We first look in the word vectors of our given classes:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nn_classes = create_index(vecs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "valid_preds = preds[-len(data.valid_ds):]\n",
    "valid_names = np.load(tmp_path/'val_names.npy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "idxs,dists = get_knns(nn_classes, valid_preds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ks = [0,10000,20000,30000]\n",
    "_,axs = plt.subplots(2,2,figsize=(12,8))\n",
    "for ax,k in zip(axs.flatten(), ks):\n",
    "    open_image(valid_names[k]).show(ax = ax)\n",
    "    # Title: the first three neighbouring class names, then the true class on a second line\n",
    "    top3 = ','.join(class_ids[classes[i]] for i in idxs[k][:3])\n",
    "    truth = class_ids[classes[valid_labels[k]]]\n",
    "    ax.set_title(f'{top3}\\n{truth}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Looking at predicted tags in all Wordnet"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's look at the words it finds in all Wordnet."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Every name in the WordNet dictionary, with its fastText sentence vector\n",
    "words = list(class_ids.values())\n",
    "wn_vecs = [en_vecs.get_sentence_vector(w.replace('_', ' ')) for w in words]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nn_wvs = create_index(wn_vecs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "idxs,dists = get_knns(nn_wvs, valid_preds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ks = [0,10000,20000,30000]\n",
    "_,axs = plt.subplots(2,2,figsize=(12,8))\n",
    "for ax,k in zip(axs.flatten(), ks):\n",
    "    open_image(valid_names[k]).show(ax = ax)\n",
    "    # Title: the first three neighbouring WordNet words, then the true class on a second line\n",
    "    top3 = ','.join(words[i] for i in idxs[k][:3])\n",
    "    truth = class_ids[classes[valid_labels[k]]]\n",
    "    ax.set_title(f'{top3}\\n{truth}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Text -> Image search"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can use the reverse approach: feed a word vector and find the image activations that match it the closest:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nn_preds = create_index(valid_preds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_imgs_from_text(text):\n",
    "    \"Show, on a 2x2 grid, the first four validation images returned by `nn_preds` for the fastText vector of `text`.\"\n",
    "    query = en_vecs.get_sentence_vector(text)\n",
    "    nearest,_ = get_knn(nn_preds, query)\n",
    "    _,axs = plt.subplots(2,2,figsize=(12,8))\n",
    "    for ax,i in zip(axs.flatten(), nearest[:4]):\n",
    "        open_image(valid_names[i]).show(ax = ax)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "'boat' isn't a label in ImageNet, yet if we ask for the images whose word vectors are the most similar to the word vector for boat..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show_imgs_from_text('boat')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "or even more precisely 'motor boat'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show_imgs_from_text('motor boat')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show_imgs_from_text('sail boat')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Image->image"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can also ask for the images with a word vector most similar to another image. This one was downloaded from Google and isn't in Imagenet."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "img = open_image('images/teddy_bear.jpg')\n",
    "img"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To get the corresponding vector, we need to feed it to the pretrained model (`body`, defined at the top) after normalizing it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "img = img.data\n",
    "m,s = imagenet_stats\n",
    "# Index the stats with [:,None,None] so they broadcast against `img`\n",
    "mean,std = tensor(m)[:,None,None],tensor(s)[:,None,None]\n",
    "x = (img - mean)/std"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Send the batch to the device `body` was moved to (defaults.device) instead of assuming CUDA,\n",
    "# and skip gradient tracking since this is inference only.\n",
    "with torch.no_grad(): activs = body.eval()(x[None].to(defaults.device))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pred = learn.model.eval()(activs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pred = pred[0].detach().cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# First four validation images returned by `nn_preds` for `pred`\n",
    "idxs,dists = get_knn(nn_preds, pred)\n",
    "_,axs = plt.subplots(2,2,figsize=(12,8))\n",
    "for ax,i in zip(axs.flatten(), idxs[:4]):\n",
    "    open_image(valid_names[i]).show(ax = ax)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
